{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6e879e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../benchmark') "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d566cc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import yaml\n",
    "import benchmark\n",
    "import utils\n",
    "from openai_cache import Completion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94e6f62d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_summarization_prompt(objects, receptacles, placements):\n",
    "    summarization_prompt_template = '''objects = [\"dried figs\", \"protein bar\", \"cornmeal\", \"Macadamia nuts\", \"vinegar\", \"herbal tea\", \"peanut oil\", \"chocolate bar\", \"bread crumbs\", \"Folgers instant coffee\"]\n",
    "receptacles = [\"top rack\", \"middle rack\", \"table\", \"shelf\", \"plastic box\"]\n",
    "pick_and_place(\"dried figs\", \"plastic box\")\n",
    "pick_and_place(\"protein bar\", \"shelf\")\n",
    "pick_and_place(\"cornmeal\", \"top rack\")\n",
    "pick_and_place(\"Macadamia nuts\", \"plastic box\")\n",
    "pick_and_place(\"vinegar\", \"middle rack\")\n",
    "pick_and_place(\"herbal tea\", \"table\")\n",
    "pick_and_place(\"peanut oil\", \"middle rack\")\n",
    "pick_and_place(\"chocolate bar\", \"shelf\")\n",
    "pick_and_place(\"bread crumbs\", \"top rack\")\n",
    "pick_and_place(\"Folgers instant coffee\", \"table\")\n",
    "# Summary: Put dry ingredients on the top rack, liquid ingredients in the middle rack, tea and coffee on the table, packaged snacks on the shelf, and dried fruits and nuts in the plastic box.\n",
    "\n",
    "objects = [\"yoga pants\", \"wool sweater\", \"black jeans\", \"Nike shorts\"]\n",
    "receptacles = [\"hamper\", \"bed\"]\n",
    "pick_and_place(\"yoga pants\", \"hamper\")\n",
    "pick_and_place(\"wool sweater\", \"bed\")\n",
    "pick_and_place(\"black jeans\", \"bed\")\n",
    "pick_and_place(\"Nike shorts\", \"hamper\")\n",
    "# Summary: Put athletic clothes in the hamper and other clothes on the bed.\n",
    "\n",
    "objects = [\"Nike sweatpants\", \"sweater\", \"cargo shorts\", \"iPhone\", \"dictionary\", \"tablet\", \"Under Armour t-shirt\", \"physics homework\"]\n",
    "receptacles = [\"backpack\", \"closet\", \"desk\", \"nightstand\"]\n",
    "pick_and_place(\"Nike sweatpants\", \"backpack\")\n",
    "pick_and_place(\"sweater\", \"closet\")\n",
    "pick_and_place(\"cargo shorts\", \"closet\")\n",
    "pick_and_place(\"iPhone\", \"nightstand\")\n",
    "pick_and_place(\"dictionary\", \"desk\")\n",
    "pick_and_place(\"tablet\", \"nightstand\")\n",
    "pick_and_place(\"Under Armour t-shirt\", \"backpack\")\n",
    "pick_and_place(\"physics homework\", \"desk\")\n",
    "# Summary: Put workout clothes in the backpack, other clothes in the closet, books and homeworks on the desk, and electronics on the nightstand.\n",
    "\n",
    "objects = {objects_str}\n",
    "receptacles = {receptacles_str}\n",
    "{placements_str}\n",
    "# Summary:'''\n",
    "    objects_str = '[' + ', '.join(map(lambda x: f'\"{x}\"', objects)) + ']'\n",
    "    receptacles_str = '[' + ', '.join(map(lambda x: f'\"{x}\"', receptacles)) + ']'\n",
    "    placements_str = '\\n'.join(map(lambda x: f'pick_and_place(\"{x[0]}\", \"{x[1]}\")', placements))\n",
    "    return summarization_prompt_template.format(objects_str=objects_str, receptacles_str=receptacles_str, placements_str=placements_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdaedc66",
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_placement_prompt(summary, objects, receptacles):\n",
    "    placement_prompt_template = '''# Summary: Put clothes in the laundry basket and toys in the storage box.\n",
    "objects = [\"socks\", \"toy car\", \"shirt\", \"Lego brick\"]\n",
    "receptacles = [\"laundry basket\", \"storage box\"]\n",
    "pick_and_place(\"socks\", \"laundry basket\")\n",
    "pick_and_place(\"toy car\", \"storage box\")\n",
    "pick_and_place(\"shirt\", \"laundry basket\")\n",
    "pick_and_place(\"Lego brick\", \"storage box\")\n",
    "\n",
    "# Summary: {summary}\n",
    "objects = {objects_str}\n",
    "receptacles = {receptacles_str}\n",
    "pick_and_place(\"{first_object}\",'''\n",
    "    objects_str = '[' + ', '.join(map(lambda x: f'\"{x}\"', objects)) + ']'\n",
    "    receptacles_str = '[' + ', '.join(map(lambda x: f'\"{x}\"', receptacles)) + ']'\n",
    "    return placement_prompt_template.format(summary=summary, objects_str=objects_str, receptacles_str=receptacles_str, first_object=objects[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42299398",
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_category_prompt(summary):\n",
    "    return '''# Summary: Put shirts on the bed, jackets and pants on the chair, and bags on the shelf.\n",
    "objects = [\"shirt\", \"jacket or pants\", \"bag\"]\n",
    "\n",
    "# Summary: Put pillows on the sofa, clothes on the chair, and shoes on the rack.\n",
    "objects = [\"pillow\", \"clothing\", \"shoe\"]\n",
    "\n",
    "# Summary: {summary}\n",
    "objects = [\"'''.format(summary=summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "482175dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_summarization_primitive_prompt(objects, primitives):\n",
    "    summarization_primitive_prompt_template = '''objects = [\"granola bar\", \"hat\", \"toy car\", \"Lego brick\", \"fruit snacks\", \"shirt\"]\n",
    "pick_and_toss(\"granola bar\")\n",
    "pick_and_place(\"hat\")\n",
    "pick_and_place(\"toy car\")\n",
    "pick_and_place(\"Lego brick\")\n",
    "pick_and_toss(\"fruit snacks\")\n",
    "pick_and_place(\"shirt\")\n",
    "# Summary: Pick and place clothes and toys, pick and toss snacks.\n",
    "\n",
    "objects = {objects_str}\n",
    "{primitives_str}\n",
    "# Summary:'''\n",
    "    objects_str = '[' + ', '.join(map(lambda x: f'\"{x}\"', objects)) + ']'\n",
    "    primitives_str = '\\n'.join(map(lambda x: f'pick_and_{x[1]}(\"{x[0]}\")', primitives))\n",
    "    return summarization_primitive_prompt_template.format(objects_str=objects_str, primitives_str=primitives_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55523281",
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_primitive_prompt(summary, objects):\n",
    "    primitive_prompt_template = '''# Summary: Pick and place clothes, pick and toss snacks.\n",
    "objects = [\"granola bar\", \"hat\", \"toy car\", \"Lego brick\", \"fruit snacks\", \"shirt\"]\n",
    "pick_and_toss(\"granola bar\")\n",
    "pick_and_place(\"hat\")\n",
    "pick_and_place(\"toy car\")\n",
    "pick_and_place(\"Lego brick\")\n",
    "pick_and_toss(\"fruit snacks\")\n",
    "pick_and_place(\"shirt\")\n",
    "\n",
    "# Summary: Pick and place granola bars, hats, toy cars, and Lego bricks, pick and toss fruit snacks and shirts.\n",
    "objects = [\"clothing\", \"snack\"]\n",
    "pick_and_place(\"clothing\")\n",
    "pick_and_toss(\"snack\")\n",
    "\n",
    "# Summary: {summary}\n",
    "objects = {objects_str}'''\n",
    "    objects_str = '[' + ', '.join(map(lambda x: f'\"{x}\"', objects)) + ']'\n",
    "    return primitive_prompt_template.format(summary=summary, objects_str=objects_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bdacdfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_categories(categories_completion):\n",
    "    categories = categories_completion.split(',')\n",
    "    categories[-1] = categories[-1].replace(']', '')\n",
    "    categories = [c.strip().replace('\"', '') for c in categories]\n",
    "    return categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b31d38a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_primitives(primitive_completion):\n",
    "    primitives = []\n",
    "    for line in primitive_completion.strip().split('\\n'):\n",
    "        if len(line) == 0:\n",
    "            print('Warning: Stopping since newline was encountered')\n",
    "            break\n",
    "        primitive, obj = line.split('(')\n",
    "        primitive = primitive.strip().replace('pick_and_', '')\n",
    "        obj = obj.strip().replace(')', '').replace('\"', '')\n",
    "        primitives.append([obj, primitive])\n",
    "    return primitives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a94efaa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_scenario(scenario_name):\n",
    "    with open(f'scenarios/{scenario_name}.yml', 'r', encoding='utf8') as f:\n",
    "        scenario_dict = yaml.safe_load(f)\n",
    "    scenario = benchmark.Scenario(\n",
    "        room=None,\n",
    "        receptacles=[r for r, info in scenario_dict['receptacles'].items() if 'primitive_names' in info],  # Ignore receptacles that are only used as a prop in this scenario (such as the sofa)\n",
    "        seen_objects=scenario_dict['seen_objects'],\n",
    "        seen_placements=scenario_dict['seen_placements'],\n",
    "        unseen_objects=scenario_dict['unseen_objects'],\n",
    "        unseen_placements=scenario_dict['unseen_placements'],\n",
    "        annotator_notes=scenario_dict['annotator_notes'],\n",
    "        tags=None)\n",
    "    scenario.seen_primitives = scenario_dict['seen_primitives']\n",
    "\n",
    "    # Sanity checks\n",
    "    assert len(set(scenario.seen_objects)) == len(scenario.seen_objects)\n",
    "    for obj, recep in scenario.seen_placements:\n",
    "        assert obj in scenario.seen_objects\n",
    "        assert recep in scenario.receptacles\n",
    "    for obj, primitive in scenario.seen_primitives:\n",
    "        assert obj in scenario.seen_objects\n",
    "        assert primitive in ['place', 'toss']\n",
    "    for obj1, (obj2, _) in zip(scenario.seen_objects, scenario.seen_placements):\n",
    "        assert obj1 == obj2\n",
    "    for obj1, (obj2, _) in zip(scenario.seen_objects, scenario.seen_primitives):\n",
    "        assert obj1 == obj2\n",
    "    assert len(set(scenario.unseen_objects)) == len(scenario.unseen_objects)\n",
    "    for obj, recep in scenario.unseen_placements:\n",
    "        assert obj in scenario.unseen_objects\n",
    "        assert recep in scenario.receptacles\n",
    "    for obj1, (obj2, _) in zip(scenario.unseen_objects, scenario.unseen_placements):\n",
    "        assert obj1 == obj2\n",
    "\n",
    "    return scenario"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdb301f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_scenario(scenario_name, verbose=False):\n",
    "    completion = Completion()\n",
    "    scenario = load_scenario(scenario_name)\n",
    "    if verbose:\n",
    "        print(f'Scenario: {scenario_name}\\n')\n",
    "\n",
    "    # Summarization\n",
    "    summarization_prompt = construct_summarization_prompt(\n",
    "        scenario.seen_objects, scenario.receptacles, scenario.seen_placements)\n",
    "    summarization_completion = completion.create(summarization_prompt)['choices'][0]['text']\n",
    "    if verbose:\n",
    "        print(summarization_prompt, end='')\n",
    "        utils.print_colored(summarization_completion, 'blue')\n",
    "        print('\\n' + 10 * '-' + '\\n')\n",
    "\n",
    "    # Seen object placement\n",
    "    summary = benchmark.parse_summary(summarization_completion)\n",
    "    placement_prompt = construct_placement_prompt(summary, scenario.seen_objects, scenario.receptacles)\n",
    "    placement_completion = completion.create(placement_prompt)['choices'][0]['text']\n",
    "    if verbose:\n",
    "        print(placement_prompt, end='')\n",
    "        utils.print_colored(placement_completion, 'blue')\n",
    "        print('\\n' + 10 * '-' + '\\n')\n",
    "\n",
    "    # Object categories\n",
    "    category_prompt = construct_category_prompt(summary)\n",
    "    category_completion = completion.create(category_prompt)['choices'][0]['text']\n",
    "    categories = parse_categories(category_completion)\n",
    "    if verbose:\n",
    "        print(category_prompt, end='')\n",
    "        utils.print_colored(category_completion, 'blue')\n",
    "        print('\\n' + 10 * '-' + '\\n')\n",
    "\n",
    "    # Object category placement\n",
    "    category_placement_prompt = construct_placement_prompt(summary, categories, scenario.receptacles)\n",
    "    category_placement_completion = completion.create(category_placement_prompt)['choices'][0]['text']\n",
    "    category_placements = benchmark.parse_placements(category_placement_completion, categories)\n",
    "    if verbose:\n",
    "        print(category_placement_prompt, end='')\n",
    "        utils.print_colored(category_placement_completion, 'blue')\n",
    "        print('\\n' + 10 * '-' + '\\n')\n",
    "\n",
    "    # Summarization for primitive selection\n",
    "    summarization_primitive_prompt = construct_summarization_primitive_prompt(\n",
    "        scenario.seen_objects, scenario.seen_primitives)\n",
    "    summarization_primitive_completion = completion.create(summarization_primitive_prompt)['choices'][0]['text']\n",
    "    if verbose:\n",
    "        print(summarization_primitive_prompt, end='')\n",
    "        utils.print_colored(summarization_primitive_completion, 'blue')\n",
    "        print('\\n' + 10 * '-' + '\\n')\n",
    "\n",
    "    # Seen object primitive selection\n",
    "    summary_primitive = benchmark.parse_summary(summarization_primitive_completion)\n",
    "    primitive_prompt = construct_primitive_prompt(summary_primitive, scenario.seen_objects)\n",
    "    primitive_completion = completion.create(primitive_prompt)['choices'][0]['text']\n",
    "    if verbose:\n",
    "        print(primitive_prompt, end='')\n",
    "        utils.print_colored(primitive_completion, 'blue')\n",
    "        print('\\n' + 10 * '-' + '\\n')\n",
    "\n",
    "    # Object category primitive selection\n",
    "    category_primitive_prompt = construct_primitive_prompt(summary_primitive, categories)\n",
    "    category_primitive_completion = completion.create(category_primitive_prompt)['choices'][0]['text']\n",
    "    category_primitives = parse_primitives(category_primitive_completion)\n",
    "    if verbose:\n",
    "        print(category_primitive_prompt, end='')\n",
    "        utils.print_colored(category_primitive_completion, 'blue')\n",
    "        print('\\n' + 10 * '-' + '\\n')\n",
    "\n",
    "    # Analysis\n",
    "    predicted_placements = benchmark.parse_placements(placement_completion, scenario.seen_objects)\n",
    "    placement_corrects, placement_accuracy = benchmark.check_placements(predicted_placements, scenario.seen_placements)\n",
    "    predicted_primitives = parse_primitives(primitive_completion)\n",
    "    primitive_corrects, primitive_accuracy = benchmark.check_placements(predicted_primitives, scenario.seen_primitives)\n",
    "    if verbose:\n",
    "        print(f'Annotator notes: {scenario.annotator_notes}\\n')\n",
    "\n",
    "        print('\\nSeen placements:')\n",
    "        for placement in scenario.seen_placements:\n",
    "            print(placement)\n",
    "        print('\\nParsed placements:')\n",
    "        for placement, correct in zip(predicted_placements, placement_corrects):\n",
    "            utils.print_colored(placement, 'green' if correct else 'red')\n",
    "        print(f'\\nSeen placement accuracy: {placement_accuracy:.2f}')\n",
    "\n",
    "        print('\\nSeen primitives:')\n",
    "        for primitive in scenario.seen_primitives:\n",
    "            print(primitive)\n",
    "        print('\\nParsed primitives:')\n",
    "        for primitive, correct in zip(predicted_primitives, primitive_corrects):\n",
    "            utils.print_colored(primitive, 'green' if correct else 'red')\n",
    "        print(f'\\nSeen primitive accuracy: {primitive_accuracy:.2f}')\n",
    "\n",
    "        print('\\nCategories:', categories)\n",
    "        print('\\nCategory placements:')\n",
    "        for placement in category_placements:\n",
    "            print(placement)\n",
    "        print('\\nCategory primitives:')\n",
    "        for primitive in category_primitives:\n",
    "            print(primitive)\n",
    "        print('\\n' + 80 * '-' + '\\n')\n",
    "\n",
    "    # YAML\n",
    "    preferences = {\n",
    "        'categories': categories,\n",
    "        'placements': dict(category_placements),\n",
    "        'primitives': dict(category_primitives),\n",
    "    }\n",
    "    with open(f'preferences/{scenario_name}.yml', 'w', encoding='utf8') as f:\n",
    "        yaml.dump(preferences, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f0308e9",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#process_scenario('test', verbose=True)\n",
    "process_scenario('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "306a4863",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for scenario_idx in range(8):\n",
    "    #process_scenario(f'scenario-{scenario_idx + 1:02}', verbose=True)\n",
    "    process_scenario(f'scenario-{scenario_idx + 1:02}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fae4dedf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
